import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter
from torch.nn import LayerNorm
from models.conv import Conv2d, FullyConvolution
from models.i3d_model2 import I3D
from models.rnn import DynamicGRU
from models.info_nce import InfoNCE
from clip import clip as Clip
import os
from models.prompt import PromptLearner, TextEncoder, CrossAttention
#from models.whitening import SAW
from models.select_object_norm import Select_object

#from models.ot import *


def load_clip_to_cpu(backbone_name="ViT-B/32"):
    """Build a CLIP model from the cached OpenAI checkpoint, deserializing it on the CPU.

    The checkpoint is fetched into ``~/.cache/clip`` by ``Clip._download`` if it is
    not there yet. It is first opened as a TorchScript (JIT) archive; if that fails
    with ``RuntimeError`` it is read as a plain pickled state dict. Either way the
    weights are handed to ``Clip.build_model``.

    Args:
        backbone_name: key into ``Clip._MODELS``; defaults to ``"ViT-B/32"``.

    Returns:
        The module returned by ``Clip.build_model``. Callers that want it on a GPU
        move it themselves (``FinalModel.__init__`` calls ``.cuda()`` on it).
    """
    url = Clip._MODELS[backbone_name]
    model_path = Clip._download(url, root=os.path.expanduser("~/.cache/clip"))

    # map_location="cpu": the tensors are only an intermediate that build_model copies
    # from, so keeping them on the CPU avoids requiring (and temporarily filling) a GPU.
    try:
        jit_model = torch.jit.load(model_path, map_location="cpu").eval()
    except RuntimeError:
        # Not a TorchScript archive: fall back to a plain pickled state dict.
        state_dict = torch.load(model_path, map_location="cpu")
    else:
        state_dict = jit_model.state_dict()

    return Clip.build_model(state_dict)

# def GOT(v_, q_):
#     cos_distance = cost_matrix_batch_torch(v_.transpose(2, 1), q_.transpose(2, 1))
#     cos_distance = cos_distance.transpose(1, 2)
#     beta = 0.1
#     min_score = cos_distance.min()
#     max_score = cos_distance.max()
#     threshold = min_score + beta * (max_score - min_score)
#     cos_dist = torch.nn.functional.relu(cos_distance - threshold)
#
#     wd = - IPOT_distance_torch_batch_uniform(cos_dist, v_.size(0), v_.size(1), q_.size(1), 30)
#
#     return wd

class StyleRandomization(nn.Module):
    """Randomly mix per-(sample, channel) feature statistics across the batch (AdaIN-style).

    Only active in training mode; in eval mode ``forward`` returns its input unchanged.

    Args:
        eps: small constant added to the variance before taking square roots.
    """

    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, x, K = 10):
        """Re-style ``x`` with statistics mixed from a randomly permuted batch.

        Each ``x[n, c]`` slice is normalized by its own mean/variance (computed over
        every dimension after the second), then de-normalized with
        ``(1 - a) * own + a * other`` statistics, where ``other`` belongs to sample
        ``randperm(N)[n]`` and ``a`` is drawn per sample from ``U[0, 1/K)``.

        Args:
            x: tensor of shape ``(N, C, *)`` with at least 3 dimensions.
            K: inverse mixing strength; larger ``K`` keeps more of the original statistics.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: if ``x`` has fewer than 3 dimensions.
        """
        if x.dim() < 3:
            raise ValueError(
                "expected x of shape (N, C, *) with at least 3 dims, got {}".format(tuple(x.shape)))
        if not self.training:
            return x

        shape = x.shape
        n, c = shape[0], shape[1]
        # reshape (not view) so non-contiguous inputs are accepted too.
        flat = x.reshape(n, c, -1)
        mean = flat.mean(-1, keepdim=True)
        var = flat.var(-1, keepdim=True)
        normalized = (flat - mean) / (var + self.eps).sqrt()

        # Same CPU draws as before (randperm, then rand); then follow x's device and
        # dtype so inputs on any GPU / in half precision work and keep their dtype.
        idx_swap = torch.randperm(n).to(x.device)
        alpha = (torch.rand(n, 1, 1) / K).to(device=x.device, dtype=x.dtype)
        mean = (1 - alpha) * mean + alpha * mean[idx_swap]
        var = (1 - alpha) * var + alpha * var[idx_swap]

        return (normalized * (var + self.eps).sqrt() + mean).reshape(shape)


class StyleRandomization_3D(nn.Module):
    """Frame-wise AdaIN-style statistic mixing for 5-D ``(N, C, T, H, W)`` tensors.

    Statistics are computed per (sample, channel, frame) over the ``H * W``
    positions and mixed with those of a randomly permuted sample. The permutation
    is shared by all frames; the mixing weight is drawn per (frame, sample).
    Only active in training mode.

    Args:
        eps: small constant added to the variance before taking square roots.
    """

    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, x, K = 1):
        """Randomize per-frame channel statistics of ``x``.

        Args:
            x: tensor of shape ``(N, C, T, H, W)``.
            K: mixing weights are drawn from ``U[0, 1/K)``.

        Returns:
            Tensor of shape ``(N, C, T, H, W)``; ``x`` itself in eval mode.
        """
        N, C, T, H, W = x.size()
        # Drawn unconditionally (also in eval mode), as before, so RNG consumption is unchanged.
        idx_swap = torch.randperm(N)
        if not self.training:
            return x

        # All frames at once instead of a Python loop over T.
        frames = x.reshape(N, C, T, H * W)
        mean = frames.mean(-1, keepdim=True)  # (N, C, T, 1)
        var = frames.var(-1, keepdim=True)
        normalized = (frames - mean) / (var + self.eps).sqrt()

        # One weight per (frame, sample): drawn as (T, N, 1, 1) on the CPU generator,
        # laid out as (N, 1, T, 1), and moved to x's device/dtype (not the current
        # CUDA device, which may differ from x's).
        alpha = (torch.rand(T, N, 1, 1) / K).permute(1, 2, 0, 3)
        alpha = alpha.to(device=x.device, dtype=x.dtype)
        idx_swap = idx_swap.to(x.device)
        mean = (1 - alpha) * mean + alpha * mean[idx_swap]
        var = (1 - alpha) * var + alpha * var[idx_swap]

        return (normalized * (var + self.eps).sqrt() + mean).reshape(N, C, T, H, W)


class CrossRandomization(nn.Module):
    """Shift per-(sample, row) feature statistics towards those of reference features.

    Only active in training mode; in eval mode ``forward`` returns ``x`` unchanged.

    Args:
        eps: small constant added to the variance before taking square roots.
    """

    def __init__(self, eps=1e-5):
        super(CrossRandomization, self).__init__()
        self.eps = eps

    def forward(self, x, select_images, K = 10):
        """Re-normalize ``x`` with statistics partly taken from ``select_images``.

        Each ``x[n, c]`` row is normalized by its own mean/variance over the last
        dimension, then de-normalized with ``(1 - a) * own + a * reference`` statistics,
        where ``a`` is drawn per sample from ``U[0, 1/K)`` and the reference statistics
        are the mean/variance of ``select_images`` over its last dimension.

        Args:
            x: tensor of shape ``(N, C, D)``.
            select_images: 3-D tensor whose per-row statistics, shape ``(N_s, S, 1)``,
                must broadcast against ``(N, C, 1)``.
            K: inverse mixing strength; larger ``K`` keeps more of ``x``'s own statistics.

        Returns:
            Tensor of shape ``(N, C, D)``.

        Raises:
            ValueError: if ``x`` or ``select_images`` is not 3-D.
        """
        if x.dim() != 3 or select_images.dim() != 3:
            raise ValueError(
                "expected 3-D x and select_images, got {} and {}".format(
                    tuple(x.shape), tuple(select_images.shape)))
        if not self.training:
            return x

        N, C, dim = x.size()
        x_mean = x.mean(-1, keepdim=True)
        x_var = x.var(-1, keepdim=True)
        normalized = (x - x_mean) / (x_var + self.eps).sqrt()

        select_mean = select_images.mean(-1, keepdim=True)
        select_var = select_images.var(-1, keepdim=True)

        # Follow x's device and dtype rather than the current CUDA device / float32.
        alpha = (torch.rand(N, 1, 1) / K).to(device=x.device, dtype=x.dtype)
        mixed_mean = (1 - alpha) * x_mean + alpha * select_mean
        mixed_var = (1 - alpha) * x_var + alpha * select_var

        return (normalized * (mixed_var + self.eps).sqrt() + mixed_mean).reshape(N, C, dim)


class CrossRandomization_3D(nn.Module):
    """Frame-wise statistic mixing for ``(N, C, T, H, W)`` video features.

    In training mode, every (sample, channel, frame) slice is normalized over its
    ``H * W`` positions. It is then de-normalized with statistics that are first
    pulled towards the average statistics of reference images (weight ``beta``).
    After that they are mixed with those of a randomly permuted sample (weight
    drawn from ``U[0, 1/K)`` per frame and sample). Eval mode returns ``x``
    unchanged.

    Args:
        eps: small constant added to the variance before taking square roots.
    """

    def __init__(self, eps=1e-5):
        super(CrossRandomization_3D, self).__init__()
        self.eps = eps

    def forward(self, x, select_images, K = 10, beta=0.007):
        """Randomize the per-frame channel statistics of ``x``.

        Args:
            x: tensor of shape ``(N, C, T, H, W)``.
            select_images: tensor of shape ``(N, S, C, h, w)``; each of the ``S``
                reference images contributes per-channel spatial mean/variance, and
                these are averaged over ``S``.
            K: mixing weights with the permuted batch are drawn from ``U[0, 1/K)``.
            beta: weight of the reference-image statistics (default 0.007, the
                previously hard-coded value).

        Returns:
            Tensor of shape ``(N, C, T, H, W)``; ``x`` itself in eval mode.
        """
        N, C, T, H, W = x.size()
        _, select_num, _, _, _ = select_images.size()
        # Drawn unconditionally (also in eval mode), as before, so RNG consumption is unchanged.
        idx_swap = torch.randperm(N)
        if not self.training:
            return x

        # Average over the S references of their per-channel spatial statistics.
        # Computed on the inputs' own device/dtype; the old CPU path crashed when it
        # added an (N, C, 1) tensor in place into a 1-element accumulator.
        per_image = select_images.reshape(N, select_num, C, -1)
        select_mean = per_image.mean(-1).mean(1).unsqueeze(-1).unsqueeze(-1)  # (N, C, 1, 1)
        select_var = per_image.var(-1).mean(1).unsqueeze(-1).unsqueeze(-1)

        frames = x.reshape(N, C, T, H * W)
        mean = frames.mean(-1, keepdim=True)  # (N, C, T, 1)
        var = frames.var(-1, keepdim=True)
        normalized = (frames - mean) / (var + self.eps).sqrt()

        mean = (1 - beta) * mean + beta * select_mean
        var = (1 - beta) * var + beta * select_var

        # One weight per (frame, sample), laid out as (N, 1, T, 1) on x's device/dtype.
        alpha = (torch.rand(T, N, 1, 1) / K).permute(1, 2, 0, 3)
        alpha = alpha.to(device=x.device, dtype=x.dtype)
        idx_swap = idx_swap.to(x.device)
        mean = (1 - alpha) * mean + alpha * mean[idx_swap]
        var = (1 - alpha) * var + alpha * var[idx_swap]

        return (normalized * (var + self.eps).sqrt() + mean).reshape(N, C, T, H, W)


def get_video_spatial_feature(featmap_H, featmap_W):
    """Return an 8-channel box-coordinate map of shape ``(1, 8, featmap_H, featmap_W)``.

    Cell ``(h, w)`` of a ``featmap_H x featmap_W`` grid spanning ``[-1, 1]`` in both
    axes is described by the channels
    ``[xmin, ymin, xmax, ymax, xctr, yctr, 1 / featmap_W, 1 / featmap_H]``.
    The values are computed in float64 and returned as a float32 tensor.
    """
    import numpy as np

    grid = np.zeros((1, 8, featmap_H, featmap_W))
    if grid.size == 0:
        # Nothing to fill for an empty grid.
        return torch.from_numpy(grid).float()

    cols = np.arange(featmap_W)
    rows = np.arange(featmap_H)
    x_lo = cols / featmap_W * 2 - 1
    x_hi = (cols + 1) / featmap_W * 2 - 1
    y_lo = rows / featmap_H * 2 - 1
    y_hi = (rows + 1) / featmap_H * 2 - 1

    # x-channels vary along columns, y-channels along rows; broadcasting fills the rest.
    grid[0, 0] = x_lo[np.newaxis, :]
    grid[0, 1] = y_lo[:, np.newaxis]
    grid[0, 2] = x_hi[np.newaxis, :]
    grid[0, 3] = y_hi[:, np.newaxis]
    grid[0, 4] = ((x_lo + x_hi) / 2)[np.newaxis, :]
    grid[0, 5] = ((y_lo + y_hi) / 2)[:, np.newaxis]
    grid[0, 6] = 1 / featmap_W
    grid[0, 7] = 1 / featmap_H
    return torch.from_numpy(grid).float()


class PosEmb(nn.Module):
    """Append 8 box-coordinate channels to a ``(B, D, H, W)`` feature map.

    Per position the appended channels are
    ``[xmin, ymin, xmax, ymax, xctr, yctr, 1/W, 1/H]`` as produced by
    ``get_video_spatial_feature``. Maps for the square sizes 10, 20, 40, 80 and 160
    are precomputed as non-trainable parameters. Other sizes (including non-square
    ones) are computed on the fly.
    """

    def __init__(self):
        super().__init__()
        self.pos10 = nn.Parameter(get_video_spatial_feature(10, 10), requires_grad=False)
        self.pos20 = nn.Parameter(get_video_spatial_feature(20, 20), requires_grad=False)
        self.pos40 = nn.Parameter(get_video_spatial_feature(40, 40), requires_grad=False)
        self.pos80 = nn.Parameter(get_video_spatial_feature(80, 80), requires_grad=False)
        self.pos160 = nn.Parameter(get_video_spatial_feature(160, 160), requires_grad=False)

    def forward(self, x):
        """Return ``torch.cat([x, coords], dim=1)``, of shape ``(B, D + 8, H, W)``.

        Args:
            x: tensor of shape ``(B, D, H, W)``.
        """
        bsz, _, h, w = x.size()
        pos_emb = getattr(self, 'pos{}'.format(h), None)
        if pos_emb is None or tuple(pos_emb.shape[-2:]) != (h, w):
            # No precomputed map for this (h, w); previously this raised
            # UnboundLocalError (or failed in expand for non-square inputs).
            pos_emb = get_video_spatial_feature(h, w)
        # .to(x.device) instead of .cuda(x.device) so CPU inputs work as well.
        pos_emb = pos_emb.to(x.device).expand(bsz, 8, h, w)
        return torch.cat([x, pos_emb], 1)


class QueryGuidedRegionAttention2D(nn.Module):
    """Region self-attention, query-guided gating and video-query attention on a 2D map.

    ``forward`` runs these steps:

    1. Project the visual map and the word-level query to ``hidden_size``.
    2. Max-pool positions into the regions given by ``segment`` and run masked
       self-attention among regions, scattering the result back (residual).
    3. Optionally gate the features channel-wise with ``summary_query`` (residual).
    4. Attend to the query through ``VisionGuidedAttention`` (residual).
    5. Apply a position-wise feed-forward layer (residual).
    6. Project back to ``src_dim1`` channels with a 1x1 convolution.

    Args:
        src_dim1: channel count of the input visual map and of the output.
        src_dim2: last-dimension size of ``query`` and ``summary_query``.
        hidden_size: internal feature size.
        region_self: stored as ``self.region_self``; ``forward`` does not read it.
    """

    def __init__(self, src_dim1, src_dim2, hidden_size, region_self=True):
        super().__init__()
        # NOTE: construction order is kept as-is; it fixes the random-init sequence.
        self.fc_input1 = nn.Linear(src_dim1, hidden_size)
        self.fc_input2 = nn.Linear(src_dim2, hidden_size)

        # Video-query attention over position-augmented (hidden_size + 8) features.
        self.attn = VisionGuidedAttention(hidden_size + 8, hidden_size, hidden_size)

        self.pos_emb = PosEmb()

        # Region self-attention projections (inputs carry the 8 PosEmb channels).
        self.fc_q = nn.Linear(hidden_size + 8, hidden_size)
        self.fc_k = nn.Linear(hidden_size + 8, hidden_size)
        self.fc_v = nn.Linear(hidden_size + 8, hidden_size)

        # Position-wise feed-forward block: hidden -> 2 * hidden -> hidden.
        self.fc1 = nn.Linear(hidden_size, hidden_size << 1)
        self.fc2 = nn.Linear(hidden_size << 1, hidden_size)

        self.hidden_size = hidden_size
        self.region_self = region_self
        self.fc_o = nn.Conv2d(hidden_size, src_dim1, kernel_size=1, stride=1, padding=0)

        # Query-guided channel gating.
        self.query_linear = nn.Linear(src_dim2, hidden_size)
        self.video_linear = nn.Linear(hidden_size, hidden_size)
        # (sic) "linar": name kept so existing checkpoints still load.
        self.attention_linar = nn.Linear(hidden_size, hidden_size)
        self.relu = nn.ReLU()
        # tanh / softmax / dropout are kept for compatibility; forward() does not use them.
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(0.1)
        self.style_random_layer = StyleRandomization()

    def forward(self, clip_h, clip_mask, segment, query, query_len, query_mask, summary_query = None, IN = False, Aug = False,IN_query = False, epoch = 0):
        """Fuse the visual map with the query.

        Args:
            clip_h: visual map of shape ``(bsz, src_dim1, h, w)``; ``h``/``w`` must be
                sizes ``PosEmb`` supports.
            clip_mask: unused here.
            segment: integer (int64) region id per position, reshapeable to
                ``(bsz, h * w)``. Ids are assumed to run from 0 to the per-sample
                maximum (the region count is taken as ``max + 1``).
            query: word-level query features with last dimension ``src_dim2``;
                presumably ``(bsz, L, src_dim2)`` — TODO confirm.
            query_len: unused here.
            query_mask: mask for ``query``, forwarded to ``self.attn``.
            summary_query: optional ``(bsz, src_dim2)`` sentence feature; when given,
                it gates the visual features channel-wise.
            IN: in training mode, randomize the channel statistics of the gate map
                over the ``h * w`` positions (``StyleRandomization``, K=2).
            Aug: forwarded to ``self.attn`` as its ``IN`` argument.
            IN_query: unused here.
            epoch: unused here.

        Returns:
            Tensor of shape ``(bsz, src_dim1, h, w)``.
        """
        bsz, dim, h, w = clip_h.size()

        x = self.fc_input1(clip_h.permute(0, 2, 3, 1))  # [bsz, h, w, hidden]
        query = self.fc_input2(query)

        # --- Region self-attention ------------------------------------------
        res = x
        x = self.pos_emb(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)  # [bsz, h, w, hidden + 8]
        segment_ = segment.reshape(bsz, h * w, 1).expand(bsz, h * w, x.size(-1))
        x = x.reshape(bsz, h * w, -1)
        # Max-pool the features of all positions sharing a region id.
        cluster_emb = scatter(src=x, index=segment_, dim=1, reduce='max')
        cluster_len = torch.max(segment.reshape(bsz, -1), dim=-1)[0] + 1
        cluster_mask = generate_mask(cluster_emb, cluster_len)
        q, k = self.fc_q(cluster_emb), self.fc_k(cluster_emb)
        v = self.fc_v(cluster_emb)
        s = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.hidden_size)
        s = s.masked_fill(cluster_mask.unsqueeze(1) == 0, float('-inf'))
        s = F.softmax(s, dim=-1)
        x = torch.matmul(s, v)
        # Copy each region's output back to every position of that region.
        segment_ = segment.reshape(bsz, h * w, 1).expand(bsz, h * w, x.size(-1))
        x = x.gather(dim=1, index=segment_)
        x = x.reshape(bsz, h, w, -1)
        x = res + x  # [bsz, h, w, hidden]

        # --- Query-guided visual attention (channel gating) ----------------------
        if summary_query is not None:
            res = x
            x = x.reshape(bsz, h * w, -1)
            summary_query = self.relu(self.query_linear(summary_query)).unsqueeze(-2)  # [bsz, 1, hidden]
            x_feature = self.relu(self.video_linear(x))
            query_x_raw = x_feature * summary_query
            query_att_map = self.attention_linar(query_x_raw).sigmoid()  # [bsz, h*w, hidden]
            if self.training and IN:
                # StyleRandomization takes a 3-D (N, C, L) tensor and normalizes over L.
                # Put channels on dim 1 so per-channel statistics are taken over the
                # h*w positions. (Passing the 4-D (bsz, h, w, C) map, as before,
                # raised a ValueError.)
                query_att_map = self.style_random_layer(query_att_map.transpose(1, 2), K=2).transpose(1, 2)
            x_feat = x_feature * query_att_map
            x = x_feat.reshape(bsz, h, w, -1) + res

        # --- Video-query attention ---------------------------------------------
        res = x
        input_x = self.pos_emb(x.permute(0, 3, 1, 2))
        x = self.attn(input_x.permute(0, 2, 3, 1), None, query, query_mask, IN = Aug)  # [bsz, h, w, hidden]
        x = F.dropout(x, p=0.1, training=self.training)  # experiment: checking whether this dropout helps
        x = res + x

        # --- Feed-forward ------------------------------------------------------
        res = x
        x = self.fc2(F.relu(self.fc1(x)))
        x = F.dropout(x, p=0.1, training=self.training)
        x = res + x

        return self.fc_o(x.permute(0, 3, 1, 2))


class FinalModel(nn.Module):
    def __init__(self, config):
        """Build every sub-module of the model.

        Args:
            config: mapping stored as ``self.config``; this method reads
                ``config['hidden_size']`` and ``config['query_output_dim']``.

        NOTE(review): ``_build_clip_encoder`` and ``_build_query_encoder`` are defined
        outside this view. ``forward`` uses ``self.i3d`` and ``self.gru``, which are
        presumably created there — confirm.

        The construction/registration order below fixes the random-initialization
        sequence and the order of ``self.parameters()``, which optimizer state dicts
        index into. Do not reorder it.
        """
        super().__init__()
        self.config = config
        self._build_clip_encoder()
        self._build_query_encoder()

        # --- Per-resolution heads ----------------------------------------------
        self.ali_predictor = {}

        # Side lengths of the five square feature-map levels, smallest first.
        self.ori_r = [10, 20, 40, 80, 160]

        # One FullyConvolution head per key 'pred_<2 * level side length>'.
        for (a, b, c) in [(self.ori_r[-1] * 2, 64, 64),
                          (self.ori_r[-2] * 2, 64, 64),
                          (self.ori_r[-3] * 2, 192, 128),
                          (self.ori_r[-4] * 2, 480, 256),
                          (self.ori_r[-5] * 2, 832, 512)]:
            self.ali_predictor['pred_{}'.format(a)] = FullyConvolution(b, c, 1)

        self.up = nn.UpsamplingBilinear2d(scale_factor=2)

        self.fconv = {}

        self.ra = {}
        # Level side length -> channel count of the visual map at that level
        # (passed as src_dim1 to the region-attention blocks).
        self.resolution = {
            self.ori_r[0]: 1024,
            self.ori_r[1]: 832,
            self.ori_r[2]: 480,
            self.ori_r[3]: 192,
            self.ori_r[4]: 64
        }

        for res in self.resolution.keys():
            self.ra['RegionAttention_{}'.format(res)] = \
                QueryGuidedRegionAttention2D(self.resolution[res], self.config['query_output_dim'] << 1,
                                             config['hidden_size'])
            self.ra['RegionAttention_{}_i3d'.format(res)] = \
                QueryGuidedRegionAttention2D(self.resolution[res], self.config['query_output_dim'] << 1,
                                             config['hidden_size'], region_self=True)
            if res != self.ori_r[-1]:
                # Maps this level's channels to those of the next (twice as large) level.
                self.fconv['Conv_{}'.format(res)] = \
                    Conv2d(self.resolution[res], self.resolution[res * 2], kernel_size=1, stride=1, padding=0)

        # Register the dict-held modules as real sub-modules; the keys become
        # attribute names and state-dict prefixes.
        for name, module in self.ra.items():
            self.add_module(name, module)
        for name, module in self.fconv.items():
            self.add_module(name, module)
        for name, module in self.ali_predictor.items():
            self.add_module(name, module)

        # --- Query / prompt encoders ----------------------------------------------
        self.style_random_layer_3d = StyleRandomization_3D()
        # nn.LSTM / nn.GRU apply `dropout` only between stacked layers. With
        # num_layers=1 the former dropout=... arguments had no effect (and PyTorch
        # warned about them), so they are omitted for all single-layer RNNs below.
        self.query_gru = nn.LSTM(input_size=768, hidden_size=1024,
                                 num_layers=1, batch_first=True, bidirectional=True)
        self.text_prompt_gru = nn.LSTM(input_size=512, hidden_size=256, num_layers=1, batch_first=True,
                                       bidirectional=True)
        self.text_prompt_encoder = nn.MultiheadAttention(512, num_heads=4, dropout=0.2)
        self.obj_cross_attention = CrossAttention(d_model=512, nhead=4)
        self.obj_gru2 = nn.GRU(input_size=512, hidden_size=256, num_layers=1, batch_first=True,
                               bidirectional=True)

        # --- CLIP backbone ---------------------------------------------------------
        self.clip_model = load_clip_to_cpu()
        self.clip_model.float()
        self.clip_model.cuda()
        self.clip_image_encoder = self.clip_model.visual
        self.logit_scale = self.clip_model.logit_scale
        self.dtype = self.clip_model.dtype

        self.prompt_learner = PromptLearner(self.clip_model)
        self.text_encoder = TextEncoder(self.clip_model)

        self.text_cross_attention = CrossAttention(d_model=512, nhead=4)
        self.visual_cross_attention1 = CrossAttention(d_model=512, nhead=4)
        self.prompt_cross_attention = CrossAttention(d_model=512, nhead=4)

        self.visual_self_attention1 = CrossAttention(d_model=512, nhead=4)
        self.visual_cross_attention2 = CrossAttention(d_model=512, nhead=4)

        self.obj_gru = nn.GRU(input_size=512, hidden_size=256, num_layers=1, batch_first=True,
                              bidirectional=True)
        self.sub_gru = nn.GRU(input_size=512, hidden_size=256, num_layers=1, batch_first=True,
                              bidirectional=True)

        # --- Object-image fusion ----------------------------------------------------
        self.softmax = nn.Softmax(dim=0)
        self.Conv_layer1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=4, stride=3, padding=9, bias=True)
        self.Conv_layer2 = nn.Conv2d(in_channels=64, out_channels=192, kernel_size=3, stride=1, padding=1, bias=True)
        self.norm_test = nn.BatchNorm2d(64)
        self.relu_test = nn.ReLU(inplace=True)

        self.object_visual_drop = nn.Dropout(0.2)
        self.object_visual_norm = nn.LayerNorm(512)
        self.Select_object = Select_object(clip_image_encoder=self.clip_image_encoder, logit_scale=self.logit_scale)
        self.CrossRandomization = CrossRandomization()
        self.CrossRandomization_3d = CrossRandomization_3D()
        self.style_random_layer = StyleRandomization()



    def _forward_impl(self, x):
        """Identity placeholder: hand ``x`` back untouched (no computation happens here)."""
        passthrough = x
        return passthrough

    def load_pretrained_weights(self, path='/home/user/openset-dg/i3d_rgb.pth'):
        """Load a state dict into ``self.i3d`` from a checkpoint file.

        Args:
            path: checkpoint location; defaults to the previously hard-coded path.

        Returns:
            The value returned by ``self.i3d.load_state_dict`` (it is also printed,
            after a separator line).

        Raises:
            RuntimeError: if the checkpoint keys do not match ``self.i3d`` exactly
                (``strict=True``).
        """
        # Deserialize on the CPU: load_state_dict copies the tensors into the existing
        # parameters wherever they live, so this neither needs the GPU the checkpoint
        # was saved from nor holds an extra copy of it on the GPU.
        state_dict = torch.load(path, map_location="cpu")
        print("-------------------------------------------")
        result = self.i3d.load_state_dict(state_dict, strict=True)
        print(result)
        return result

    def forward(self, clip, clip_image, maskrcnn_object_list, original_query_text, obj_query_text, sub_query_text, query, query_len,
                coarse_gt_mask=None, fine_gt_mask=None, anch_mask=None, mask=None, segment=None, epoch=None,
                **kwargs):
        #print("query:",query.size())
        logit_scale = self.logit_scale.exp()
        query_mask = generate_mask(query, query_len)
        #query_aug = random_swap(query, n=1)
        #print("query_aug:",query_aug.size())
        #query_h_aug = self.gru(query_aug, query_len)
       # _, (summary_query_aug, _) = self.query_gru(query_aug)

        query_h = self.gru(query, query_len)
        bs = query_h.size(0)
        _, (summary_query, _) = self.query_gru(query)
        summary_query = summary_query.transpose(0, 1).reshape(bs, -1)
        #summary_query_aug = summary_query_aug.transpose(0, 1).reshape(bs, -1)

        # if self.training and epoch >= 6:
        #     if_aug = random.randint(1, 2)
        #     if if_aug == 1:
        #         query_h = query_h_aug
        #         summary_query = summary_query_aug
        #     else:
        #         pass

        clip_image = clip_image.permute(0, 3, 1, 2).float()
        with torch.no_grad():

            clip_image_feature = self.clip_image_encoder(clip_image)
        batch_size, length, c, h, w = maskrcnn_object_list.size()
        obj_img_list = maskrcnn_object_list.reshape(batch_size * length, c, h, w) # (16, 3, 224, 224)
        obj_img_list2 = obj_img_list.clone()
        #obj_img_list2 = obj_img_list2.reshape(batch_size, length, c, h, w)





        #print("obj_img_list:",obj_img_list.size())
        with torch.no_grad():
            object_feature = self.clip_image_encoder(obj_img_list)  # (16, 512)
            original_object_image_feature = object_feature
            original_object_image_feature = original_object_image_feature.reshape(batch_size, length, -1)


        #print("object_feature:",object_feature.size())



        prompts = self.prompt_learner(clip_image_feature, obj_query_text, sub_query_text)
        text_result = self.text_encoder(prompts)


        #print("text_feature:",text_result.size())
        obj_part = text_result[:,0:16,:]
        other_part = text_result[:,16:,:]
        object_feature = object_feature.reshape(batch_size, length, -1)
        clip_image_feature = clip_image_feature.unsqueeze(1)
        object_feature = torch.cat([clip_image_feature, object_feature],dim=1).permute(1, 0, 2) # [5, 4, 512]

        clip_image_feature = clip_image_feature.squeeze(1)

        clip_image_feature = clip_image_feature.unsqueeze(1).transpose(0, 1)
        #**********************************************************************
        #text_result = self.style_random_layer(text_result, K=5)
        # **********************************************************************
        text_result = text_result.transpose(0, 1)
        obj_result = obj_part.transpose(0, 1)
        other_part = other_part.permute(1, 0, 2)  # [61, 4, 512]
        text_features = self.text_prompt_encoder(clip_image_feature, text_result, text_result)[0]
        #obj_features = self.obj_prompt_encoder(clip_image_feature, obj_result, obj_result)[0]
        (_, obj_part_gru) = self.obj_gru(obj_result.transpose(0, 1))
        obj_result = self.obj_cross_attention(obj_result, other_part)
        (_, obj_part_gru_2) = self.obj_gru2(obj_result.transpose(0,1))
        obj_part_gru = obj_part_gru.transpose(0, 1).reshape(batch_size, -1)
        obj_part_gru_2 = obj_part_gru_2.transpose(0, 1).reshape(batch_size, -1)
        #print("obj_part_gru:",obj_part_gru.size())

        clip_image_feature = clip_image_feature.transpose(0, 1).squeeze(1)


        if self.training and epoch >= 10:
            with torch.no_grad():
                sum_select_images = self.Select_object(text_features, obj_img_list2)  # [batch, num, dim]
                sum_select_images = sum_select_images.squeeze(0)
                sum_select_images = sum_select_images.permute(1, 0, 2, 3, 4)
                # print("sum_select_images:",sum_select_images.size())
                batch_size, num, c, h, w = sum_select_images.size()


        if self.training and epoch >= 13:
            with torch.no_grad():
                object_feature = object_feature.permute(1, 0, 2)  # [4, 5, 512]
                sum_select_images = sum_select_images.reshape(batch_size * num, c, h, w)
                sum_select_features = self.clip_image_encoder(sum_select_images)
                sum_select_features = sum_select_features.reshape(batch_size, num, -1)
                object_feature = self.CrossRandomization(object_feature, sum_select_features, K=10)
                object_feature = object_feature.permute(1, 0, 2)



        #object_feature = torch.cat([object_feature, clip_image_feature], dim=1).permute(1, 0, 2)
        #object_feature = object_feature.permute(1, 0, 2)

        #other_part_attention = self.text_cross_attention(other_part, object_feature).permute(1, 0, 2)
        object_feature_attention = self.visual_cross_attention1(object_feature, other_part)  #[5, 4, 512]
        object_feature_attention = self.visual_self_attention1(object_feature_attention, object_feature_attention)
        object_feature_attention = self.visual_cross_attention2(object_feature_attention, other_part).permute(1, 0, 2)

        object_feature_attention2 = object_feature_attention


        #(_, other_part_gru) = self.sub_gru(other_part_attention)
        #other_part_gru = other_part_gru.transpose(0, 1).reshape(batch_size, -1)
        #(_, object_feature_gru) = self.mask_gru(object_feature_attention)
        #object_feature_gru= object_feature_gru.transpose(0, 1).reshape(batch_size, -1)
        object_feature_gru = object_feature_attention[:,0,:]
        #print("object_feature_gur:",object_feature_gru.size())

        label_truth = torch.arange(batch_size).cuda()
        loss_other = nn.CrossEntropyLoss().cuda()
        loss_mask = nn.CrossEntropyLoss().cuda()

        #logits_other = logit_scale * obj_part_gru @ other_part_gru.t()
        logits_mask = logit_scale * obj_part_gru @ object_feature_gru.t()
        #prompt_loss = (loss_other(logits_other, label_truth) + loss_mask(logits_mask, label_truth))/2
        prompt_loss = loss_mask(logits_mask, label_truth)


        #original_query_feature = self.ff_query(query_h)
        #object_feature_attention = torch.cat([other_part_attention, object_feature_attention],dim=1).permute(1, 0, 2)
        #替换成other_part_attention?
        #obj_part = obj_part.permute(1, 0, 2)
        #obj_part = self.prompt_cross_attention(obj_part, object_feature_attention).permute(1, 0, 2)


        object_feature = object_feature.permute(1, 0, 2) + self.object_visual_drop(object_feature_attention2) #[4, 5, 512]
        object_feature = object_feature[:, 1:5, :]
        object_feature = self.object_visual_norm(object_feature)



        #text_features = text_features.transpose(0, 1).squeeze(1)
        #obj_features = obj_features.transpose(0, 1).squeeze(1)



        #clip_image_feature = clip_image_feature / clip_image_feature.norm(dim=1, keepdim=True)  #(4, 512)
        #print("clip_image_feature:",clip_image_feature.size())
        #text_features = text_features / text_features.norm(dim=1, keepdim=True)
        obj_features = obj_part_gru_2 / obj_part_gru_2.norm(dim=1, keepdim=True)

        obj_img_list = obj_img_list.reshape(batch_size, length, c, h, w)



        sum_image = []
        for b in range(batch_size):
            obj_per_image = object_feature[b, :, :]
            obj_list_per_image = obj_img_list[b, :, :, :, :]

            obj_per_image = obj_per_image / obj_per_image.norm(dim=1, keepdim=True)

            obj_feature = obj_features[b, :]

            logits_per_image = self.logit_scale * obj_per_image @ obj_feature.t()


            logits_per_image = self.softmax(logits_per_image).unsqueeze(1).cuda()
            logits_per_image = logits_per_image * 4
            #logits_per_image = logits_per_image.unsqueeze(1).cuda()
            sum_per_image = torch.zeros((3, 224, 224)).cuda()

            for i in range(len(logits_per_image)):
                sum_per_image += logits_per_image[i] * obj_list_per_image[i]
            # sum_per_image = sum_per_image.sum(dim=0)
            sum_image.append(sum_per_image)
        sum_image = torch.stack(sum_image)

        sum_image = self.Conv_layer1(sum_image)
        sum_image = self.Conv_layer2(sum_image)     #[4, 192, 80, 80]


        # if self.training and epoch >=58:
        #     with torch.no_grad():
        #         sum_select_images = self.Select_object(text_features, obj_img_list2)
        #
        # if self.training and epoch >= 60:
        #     with torch.no_grad():
        #         batch_size, select_number, dim, h, w = sum_select_images.size()
        #         sum_select_images = sum_select_images.view(batch_size*select_number, dim, h, w)
        #         sum_select_images = self.Conv_layer1(sum_select_images)
        #         sum_select_images = self.Conv_layer2(sum_select_images)
        #         _, dim, h, w = sum_select_images.size()
        #         sum_select_images = sum_select_images.view(batch_size, select_number, dim, h, w)
        #         sum_image = self.CrossRandomization(sum_image, sum_select_images, K=20)


        # if self.training and epoch >= 6:
        #     whitening_loss = self.whitening_layer(sum_image)
        # else:
        #     whitening_loss = torch.FloatTensor([0]).cuda()


        #sum_image = sum_image+text_result

        #
        # logits_per_image = logit_scale * clip_image_feature @ text_features.t()
        # logits_per_text = logits_per_image.t()
        # label_truth = torch.arange(bs).cuda()
        # loss_img = nn.CrossEntropyLoss().cuda()
        # loss_txt = nn.CrossEntropyLoss().cuda()
        # clip_loss = (loss_img(logits_per_image, label_truth) + loss_txt(logits_per_text, label_truth))/2

        multi_res = {}

        def add_pos_emb(x):
            return x

        clip_h = clip         # (4,3,8,320,320)
        #print("clip_h1:",clip_h.size())
        clip_h = self.i3d.conv3d_1a_7x7(clip_h)
        # print("clip_h2:", clip_h.size())

        clip_h = add_pos_emb(clip_h)  # (4, 64, 4, 160, 160)
        # *****************************************************************
        # if self.training and epoch < 100:
        #     clip_h = self.style_random_layer_3d(clip_h, K=5)  # 测试3D AdaIN的效果
        # *****************************************************************


        res = self.ori_r[-1]
        x = clip_h.mean(dim=2)  # [batch, channel, H, W]  (4, 64, 160, 160)
        # *****************************************************************
        # x = x.permute(1, 0, 2, 3)
        # x = self.style_random_layer(x)  # 测试AdaIN的效果
        # x = x.permute(1, 0, 2, 3)
        # *****************************************************************

        x = self._modules['RegionAttention_{}_i3d'.format(res)](x, mask[res], segment[res], query_h, query_len,
                                                                    query_mask, summary_query)

        clip_h = (clip_h + x.unsqueeze(2))


        multi_res[res] = clip_h.mean(dim=2)

        clip_h = self.i3d.maxPool3d_2a_3x3(clip_h)
        clip_h = self.i3d.conv3d_2b_1x1(clip_h)
        clip_h = self.i3d.conv3d_2c_3x3(clip_h)

        clip_h = add_pos_emb(clip_h) #  (4, 192, 4, 80, 80)
        #original_clip_h = clip_h
        # if self.training and epoch >= 20:
        #     with torch.no_grad():
        #         sum_select_images = self.Select_object(text_features, obj_img_list2)
        # if self.training and epoch >= 23:
        #     with torch.no_grad():
        #         batch_size, select_number, dim, h, w = sum_select_images.size()
        #         sum_select_images = sum_select_images.view(batch_size * select_number, dim, h, w)
        #         sum_select_images = self.Conv_layer1(sum_select_images)
        #         sum_select_images = self.Conv_layer2(sum_select_images)
        #         _, dim, h, w = sum_select_images.size()
        #         sum_select_images = sum_select_images.view(batch_size, select_number, dim, h, w)
        #         clip_h = self.CrossRandomization_3d(clip_h, sum_select_images, K=4)

        # if random.randint(0,1) == 0:
        #     clip_h = original_clip_h
        res = self.ori_r[-2]
        x = clip_h.mean(dim=2)   # (4, 192, 80, 80)
        #x = visual_concat(x, sum_image)
        x = sum_image
        x = self._modules['RegionAttention_{}_i3d'.format(res)](x, mask[res], segment[res], query_h, query_len,
                                                                query_mask, summary_query)

        clip_h = (clip_h + x.unsqueeze(2))
        multi_res[res] = clip_h.mean(dim=2)

        # print(out.shape) #192, 8, 128, 128 when 512 input
        clip_h = self.i3d.maxPool3d_3a_3x3(clip_h)
        clip_h = self.i3d.mixed_3b(clip_h)
        clip_h = self.i3d.mixed_3c(clip_h)

        clip_h = add_pos_emb(clip_h)
        res = self.ori_r[-3]
        x = clip_h.mean(dim=2)  # + self._modules['conv_fuck_{}'.format(res)](self.pos_emb[res][0].cuda())
        x = self._modules['RegionAttention_{}_i3d'.format(res)](x, mask[res], segment[res], query_h, query_len,
                                                                query_mask, summary_query)
        clip_h = (clip_h + x.unsqueeze(2))
        multi_res[res] = clip_h.mean(dim=2)

        # print(out.shape) #480, 8, 64, 64 when 512 input
        clip_h = self.i3d.maxPool3d_4a_3x3(clip_h)
        # print(out.shape) #480, 4, 32, 32 when 512 input
        clip_h = self.i3d.mixed_4b(clip_h)
        clip_h = self.i3d.mixed_4c(clip_h)
        clip_h = self.i3d.mixed_4d(clip_h)
        clip_h = self.i3d.mixed_4e(clip_h)
        # print(out.shape) #528, 4, 32, 32 when 512 input
        clip_h = self.i3d.mixed_4f(clip_h)

        clip_h = add_pos_emb(clip_h)
        res = self.ori_r[-4]
        x = clip_h.mean(dim=2)  # + self._modules['conv_fuck_{}'.format(res)](self.pos_emb[res][0].cuda())




        x = self._modules['RegionAttention_{}_i3d'.format(res)](x, mask[res], segment[res], query_h, query_len,
                                                                query_mask, summary_query)
        clip_h = (clip_h + x.unsqueeze(2))
        multi_res[res] = clip_h.mean(dim=2)

        # print(out.shape) #832, 4, 32, 32 when 512 input
        clip_h = self.i3d.maxPool3d_5a_2x2(clip_h)
        clip_h = self.i3d.mixed_5b(clip_h)
        clip_h = self.i3d.mixed_5c(clip_h)

        clip_h = add_pos_emb(clip_h)
        res = self.ori_r[-5]
        x = clip_h.mean(dim=2)  # + self._modules['conv_fuck_{}'.format(res)](self.pos_emb[res][0].cuda())
        x = self._modules['RegionAttention_{}_i3d'.format(res)](x, mask[res], segment[res], query_h, query_len,
                                                                query_mask, summary_query)
        clip_h = (clip_h + x.unsqueeze(2))
        multi_res[res] = clip_h.mean(dim=2)

        ali_score_map = {}
        x = 0.0
        for layer_idx, res in enumerate(self.resolution):
            x = x  # + self._modules['conv_fuck_{}'.format(res)](self.pos_emb[res][0].cuda())
            x = (x + multi_res[res])
            x = x + self._modules['RegionAttention_{}'.format(res)](x, mask[res], segment[res],
                                                                    query_h, query_len, query_mask)
            if res != self.ori_r[-1]:
                x = self._modules['Conv_{}'.format(res)](x, mask[res])
            x = self.up(x)
            if 'pred_{}'.format(res * 2) in self.ali_predictor.keys():
                input_x = x
                # if res != 128:
                #     input_x = torch.cat([input_x, multi_res[res * 2]], dim=1)
                ali_score_map[res * 2] = self._modules['pred_{}'.format(res * 2)](input_x, mask[res * 2])
                ali_score_map[res * 2] = torch.sigmoid(
                    ali_score_map[res * 2])  #
                if not self.training:
                    ali_score_map[res * 2] = ali_score_map[res * 2].masked_fill(mask[res * 2].unsqueeze(1) == 0, 0)

        final_dict = {
            'ali_score_map': ali_score_map,
            # 'fix_score_map': fix_score_map,
            'fine_gt_mask': fine_gt_mask,
            'coarse_gt_mask': coarse_gt_mask,
            'mask': mask
        }

        if self.training and False:
            contrast_score, _, _ = self._contrastive_score(emb, mask[self.ori_r[-1] * 2], anch_mask[self.ori_r[-1] * 2],
                                                           fg_score_map)
            final_dict.update({
                'contrast_score': contrast_score,
                # 'diversity_loss': diversity_loss,
                # 'same_loss': same_loss,
            })
        if self.training and epoch >=0:
            return final_dict, prompt_loss

        return final_dict, prompt_loss

    def _contrastive_score(self, clip_h, mask, anch_mask, fg_score_map):

        bsz, dim, h, w = clip_h.size()
        anch_mask = anch_mask.unsqueeze(1)

        mask1 = (anch_mask == 1).long()
        score = fg_score_map.masked_fill(mask1 == 0, float('-1e30')).reshape(bsz, h * w)
        score = F.softmax(score, dim=-1)
        diversity_loss = -(score * torch.log(score + 1e-10)).sum(dim=-1).mean(dim=0)
        score = score.reshape(bsz, 1, h, w)
        anchor_emb = (clip_h * score).sum(dim=-1).sum(dim=-1)
        anchor_emb = F.normalize(anchor_emb, dim=-1)

        mask1 = (anch_mask == 2).long()
        score = fg_score_map.masked_fill(mask1 == 0, float('-1e30')).reshape(bsz, h * w)
        score = F.softmax(score, dim=-1).reshape(bsz, 1, h, w)
        pos_emb = (clip_h * score).sum(dim=-1).sum(dim=-1)

        # print(anchor_emb.size(), pos_emb.size())

        exist_mask = [torch.ones(bsz, 1).type_as(mask)]
        contrast_emb = [pos_emb]

        same_loss = 0.0

        for neg_idx in [3, 4, 5, 6]:
            mask1 = (anch_mask == neg_idx).long()
            is_exist = ((mask1 == 1).sum(dim=-1).sum(dim=-1) > 0).long()
            score = fg_score_map.masked_fill(mask1 == 0, float('-1e30')).reshape(bsz, h * w)
            score = F.softmax(score, dim=-1).reshape(bsz, 1, h, w)
            same_loss += ((-(score * torch.log(score + 1e-10)).sum(dim=-1))
                          * is_exist.float()).sum() / (is_exist.sum().float() + 1e-10)

            neg_emb = (clip_h * score).sum(dim=-1).sum(dim=-1)
            contrast_emb.append(neg_emb)
            exist_mask.append(is_exist)
        same_loss /= 4
        exist_mask = torch.cat(exist_mask, dim=1)

        contrast_emb = F.normalize(torch.stack(contrast_emb, dim=1), dim=-1)

        score = torch.matmul(anchor_emb.unsqueeze(1), contrast_emb.transpose(-1, -2)).squeeze(1) * self.lambda_.cuda()
        score = F.softmax(score.masked_fill(exist_mask == 0, float('-1e30')), dim=-1)
        return score[:, 0], diversity_loss, same_loss

    def _build_clip_encoder(self):
        """Create the I3D video backbone (RGB modality, num_classes=400) as ``self.i3d``."""
        self.i3d = I3D(modality='rgb', num_classes=400)

    def _build_query_encoder(self):
        """Create the word-embedding table and bidirectional GRU for the text query.

        The embedding has 1200 rows with index 0 as padding; its width and the
        GRU input size come from ``config['query_dim']``, and the GRU hidden size
        from ``config['query_output_dim']``.
        """
        embed_dim = self.config['query_dim']
        self.word2vec = nn.Embedding(1200, embed_dim, padding_idx=0)
        self.gru = DynamicGRU(
            embed_dim,
            self.config['query_output_dim'],
            num_layers=1,
            bidirectional=True,
            batch_first=True,
        )


class VisionGuidedAttention(nn.Module):
    """Attend from every spatial cell of a visual map to the query tokens.

    Each cell of ``clip`` (bsz, h, w, src_dim) forms a query against the tokens
    of ``query`` (bsz, l, src_dim2). The attended token features are projected to
    ``src_dim - 8`` channels, giving an output of shape (bsz, h, w, src_dim - 8).

    Args:
        src_dim: channel size of ``clip``.
        src_dim2: channel size of ``query``.
        hidden_size: total attention width, split evenly across heads.
        num_heads: number of attention heads (default 1).
    """

    def __init__(self, src_dim, src_dim2, hidden_size, num_heads=1):
        super().__init__()
        if hidden_size % num_heads != 0:
            raise ValueError('hidden_size must be divisible by num_heads')
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.fc1 = nn.Linear(src_dim, self.hidden_size)
        self.fc2 = nn.Linear(src_dim2, self.hidden_size)
        self.fc3 = nn.Linear(src_dim2, self.hidden_size)
        self.fco = nn.Linear(self.hidden_size, src_dim - 8)
        self.style_random_layer = StyleRandomization()

    def forward(self, clip, clip_mask, query, query_mask=None, IN = False):
        """Return query features attended for every cell of ``clip``.

        Args:
            clip: (bsz, h, w, src_dim) visual features.
            clip_mask: accepted for interface compatibility; unused.
            query: (bsz, l, src_dim2) token features.
            query_mask: optional (bsz, l); tokens equal to 0 get no attention.
            IN: when true and in training mode, perturb the attention logits
                with ``StyleRandomization``.
        """
        bsz, h, w, _ = clip.size()
        l = query.size(1)
        nh = self.num_heads

        a = self.fc1(clip).reshape(bsz, h * w, nh, -1)
        b = self.fc2(query).reshape(bsz, l, nh, -1)
        c = self.fc3(query).reshape(bsz, l, nh, -1)
        # Scaled dot-product logits: (bsz, h*w, l, nh).
        score = torch.einsum('bihd,bjhd->bijh', a, b) / math.sqrt(self.hidden_size // nh)

        if self.training and IN:
            # StyleRandomization unpacks a 3-D (N, C, H) input, so flatten the
            # (h*w, l) logit plane per (batch, head) before perturbing it.
            # Assumes it returns a tensor of the same shape — TODO confirm.
            flat = score.permute(0, 3, 1, 2).reshape(bsz, nh, h * w * l)
            flat = self.style_random_layer(flat, K=2)
            score = flat.reshape(bsz, nh, h * w, l).permute(0, 2, 3, 1)

        if query_mask is not None:
            score = score.masked_fill(query_mask.unsqueeze(-1).unsqueeze(1) == 0,
                                      float('-inf'))
        # Normalize over the query tokens (dim -2 is l).
        score = F.softmax(score, -2)

        query_ = torch.einsum('bijh,bjhd->bihd', score, c).reshape(bsz, h, w, -1)
        return self.fco(query_)


class TanhAttention(nn.Module):
    """Additive (tanh) attention of ``x`` over ``memory``.

    Logits are ``wst(tanh(ws1(x_i) + ws2(m_j))))``; ``forward`` returns
    ``(softmax(logits) @ memory, softmax(logits))``. When ``fast_weights`` is
    supplied, the projections use the tensors stored under ``'ws1.weight'``,
    ``'ws1.bias'``, ``'ws2.weight'`` and ``'wst.weight'`` instead of the module's
    own parameters.
    """

    def __init__(self, d_model):
        super().__init__()
        self.ws1 = nn.Linear(d_model, d_model, bias=True)
        self.ws2 = nn.Linear(d_model, d_model, bias=False)
        self.wst = nn.Linear(d_model, 1, bias=False)

    def reset_parameters(self):
        """Re-initialize ws1, ws2 and wst, in that order."""
        for layer in (self.ws1, self.ws2, self.wst):
            layer.reset_parameters()

    def forward(self, x, memory, memory_mask=None, fast_weights=None, **kwargs):
        """Attend from ``x`` [nb, len1, d] to ``memory`` [nb, len2, d].

        Args:
            memory_mask: optional [nb, len2]; positions equal to 0 are excluded.
            fast_weights: optional dict of replacement weights (see class doc).
            **kwargs: accepted and ignored.

        Returns:
            (attended [nb, len1, d], weights [nb, len1, len2]).
        """
        use_fast = fast_weights is not None
        if use_fast:
            proj_x = F.linear(x, fast_weights['ws1.weight'], fast_weights['ws1.bias'])
            proj_mem = F.linear(memory, fast_weights['ws2.weight'])
        else:
            proj_x = self.ws1(x)
            proj_mem = self.ws2(memory)

        # Pairwise energies: [nb, len1, len2, d].
        energy = torch.tanh(proj_x.unsqueeze(2) + proj_mem.unsqueeze(1))
        if use_fast:
            logits = F.linear(energy, fast_weights['wst.weight'])
        else:
            logits = self.wst(energy)
        logits = logits.squeeze(-1)  # [nb, len1, len2]

        if memory_mask is not None:
            logits = logits.masked_fill(memory_mask.unsqueeze(1) == 0, float('-inf'))
        weights = F.softmax(logits, -1)
        return torch.matmul(weights, memory), weights


class CrossGate(nn.Module):
    """Mutually gate two feature tensors.

    Each input is multiplied by a sigmoid gate computed from the other input.

    Args:
        h1: last-dim size of ``x1``.
        h2: last-dim size of ``x2``.
    """

    def __init__(self, h1, h2):
        super().__init__()
        self.g1 = nn.Linear(h2, h1)  # maps x2 to a gate for x1
        self.g2 = nn.Linear(h1, h2)  # maps x1 to a gate for x2

    def forward(self, x1, x2):
        """Return ``(x1 * sigmoid(g1(x2)), x2 * sigmoid(g2(x1)))``."""
        gate_for_x1 = torch.sigmoid(self.g1(x2))
        gate_for_x2 = torch.sigmoid(self.g2(x1))
        return x1 * gate_for_x1, x2 * gate_for_x2


def generate_mask(x, x_len):
    """Build a length mask: row i has 1 in its first ``x_len[i]`` positions, else 0.

    Args:
        x: tensor whose ``size(1)`` gives the mask width; the mask is created
            on ``x.device``.
        x_len: per-row valid lengths (a list of ints or a 1-D tensor).

    Returns:
        LongTensor of shape (len(x_len), x.size(1)).
    """
    max_len = x.size(1)
    lengths = torch.as_tensor(x_len, device=x.device).reshape(-1, 1)
    positions = torch.arange(max_len, device=x.device).unsqueeze(0)
    # Broadcast (1, max_len) against (batch, 1) instead of filling row by row.
    return (positions < lengths).long()
#
# import random
# import copy
# def swap_word(new_words):
#     random_idx_1 = random.randint(0, len(new_words) - 1)
#     random_idx_2 = random_idx_1
#     counter = 0
#     while random_idx_2 == random_idx_1:
#         random_idx_2 = random.randint(0, len(new_words) - 1)
#         counter += 1
#         if counter >= 2:
#             return new_words
#     new_words[random_idx_1], new_words[random_idx_2] = new_words[random_idx_2], new_words[random_idx_1]
#     return new_words
#
# def random_swap(words, n):
#     new_words = copy.deepcopy(words)
#     for _ in range(n):
#         new_words = swap_word(new_words)
#     return new_words
#



def generate_coordinate_emb(clip, h, w):
    """Return a (bsz, h, w, 2) grid of unit-length (row, col) coordinate pairs.

    Row and column coordinates are spaced evenly over [-1, 1]. Each (row, col)
    pair is then scaled to unit L2 length; the exact centre of an odd-sized grid
    stays (0, 0). The dtype/device follow ``clip``, and ``bsz = clip.size(0)``.
    The batch dimension is an expanded view (not a copy).
    """
    bsz = clip.size(0)
    rows = 2 * (torch.linspace(0, h, h).type_as(clip) / h) - 1
    cols = 2 * (torch.linspace(0, w, w).type_as(clip) / w) - 1
    co = torch.stack([rows.unsqueeze(-1).expand(h, w),
                      cols.unsqueeze(0).expand(h, w)], -1)
    # `dim` must be passed by keyword: F.normalize's second positional argument
    # is the norm order `p`, not the dimension.
    co = F.normalize(co, dim=-1)
    return co.unsqueeze(0).expand(bsz, -1, -1, -1)


if __name__ == '__main__':
    # Smoke test: build the A2D training loader and the model, run one forward
    # pass on the first batch, then exit.
    from torch.utils.data import DataLoader
    from fairseq.utils import move_to_cuda
    from datasets.a2d import A2D

    dataset_cfg = dict(
        videoset_path="/home1/user/data/A2D/Release/videoset.csv",
        annotation_path="/home1/user/data/A2D/Release/Annotations",
        vocab_path="/home1/user/code/mm-2020/data/glove_a2d.bin",
        sample_path="/home1/user/data/A2D/a2d_annotation2.txt",
        max_num_words=20,
    )
    dataset = A2D(dataset_cfg)
    loader = DataLoader(
        dataset.train_set,
        batch_size=4,
        shuffle=True,
        num_workers=1,
        pin_memory=True,
        collate_fn=dataset.collate_fn,
    )
    model_cfg = dict(hidden_size=256, clip_dim=832, query_dim=300)
    model = MainModel(model_cfg).cuda()
    for batch in loader:
        output, clip_loss = model(**move_to_cuda(batch['net_input']))
        exit(0)
